{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 用于预训练词嵌入的数据集\n",
    ":label:`sec_word2vec_data`\n",
    "\n",
    "现在我们已经了解了word2vec模型的技术细节和大致的训练方法，让我们来看看它们的实现。具体地说，我们将以 :numref:`sec_word2vec`的跳元模型和 :numref:`sec_approx_train`的负采样为例。在本节中，我们从用于预训练词嵌入模型的数据集开始：数据的原始格式将被转换为可以在训练期间迭代的小批量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/Animator.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/timemachine/Vocab.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.stream.*;\n",
    "import org.apache.commons.math3.distribution.EnumeratedDistribution;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 正在读取数据集\n",
    "\n",
    "我们在这里使用的数据集是[Penn Tree Bank（PTB）](https://catalog.ldc.upenn.edu/LDC99T42)。该语料库取自“华尔街日报”的文章，分为训练集、验证集和测试集。在原始格式中，文本文件的每一行表示由空格分隔的一句话。在这里，我们将每个单词视为一个词元。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String[][] readPTB() throws IOException {\n",
    "    String ptbURL = \"http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip\";\n",
    "    InputStream input = new URL(ptbURL).openStream();\n",
    "    ZipUtils.unzip(input, Paths.get(\"./\"));\n",
    "\n",
    "    ArrayList<String> lines = new ArrayList<>();\n",
    "    File file = new File(\"./ptb/ptb.train.txt\");\n",
    "    Scanner myReader = new Scanner(file);\n",
    "    while (myReader.hasNextLine()) {\n",
    "        lines.add(myReader.nextLine());\n",
    "    }\n",
    "    String[][] tokens = new String[lines.size()][];\n",
    "    for (int i = 0; i < lines.size(); i++) {\n",
    "        tokens[i] = lines.get(i).trim().split(\" \");\n",
    "    }\n",
    "    return tokens;\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[][] sentences = readPTB();\n",
    "System.out.println(\"# sentences: \" + sentences.length);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在读取训练集之后，我们为语料库构建了一个词表，其中出现次数少于10次的任何单词都将由“&lt;unk&gt;”词元替换。请注意，原始数据集还包含表示稀有（未知）单词的“&lt;unk&gt;”词元。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Vocab vocab = new Vocab(sentences, 10, new String[] {});\n",
    "System.out.println(vocab.length());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 下采样\n",
    "\n",
    "文本数据通常有“the”、“a”和“in”等高频词：它们在非常大的语料库中甚至可能出现数十亿次。然而，这些词经常在上下文窗口中与许多不同的词共同出现，提供的有用信息很少。例如，考虑上下文窗口中的词“chip”：直观地说，它与低频单词“intel”的共现比与高频单词“a”的共现在训练中更有用。此外，大量（高频）单词的训练速度很慢。因此，当训练词嵌入模型时，可以对高频单词进行*下采样* :cite:`Mikolov.Sutskever.Chen.ea.2013`。具体地说，数据集中的每个词$w_i$将有概率地被丢弃\n",
    "\n",
    "$$ P(w_i) = \\max\\left(1 - \\sqrt{\\frac{t}{f(w_i)}}, 0\\right),$$\n",
    "\n",
    "其中$f(w_i)$是$w_i$的词数与数据集中的总词数的比率，常量$t$是超参数（在实验中为$10^{-4}$）。我们可以看到，只有当相对比率$f(w_i) > t$时，（高频）词$w_i$才能被丢弃，且该词的相对比率越高，被丢弃的概率就越大。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static boolean keep(String token, LinkedHashMap<?, Integer> counter, int numTokens) {\n",
    "    // Return True if to keep this token during subsampling\n",
    "    return new Random().nextFloat() < Math.sqrt(1e-4 / counter.get(token) * numTokens);\n",
    "}\n",
    "\n",
    "public static String[][] subSampling(String[][] sentences, Vocab vocab) {\n",
    "    for (int i = 0; i < sentences.length; i++) {\n",
    "        for (int j = 0; j < sentences[i].length; j++) {\n",
    "            sentences[i][j] = vocab.idxToToken.get(vocab.getIdx(sentences[i][j]));\n",
    "        }\n",
    "    }\n",
    "    // Count the frequency for each word\n",
    "    LinkedHashMap<?, Integer> counter = vocab.countCorpus2D(sentences);\n",
    "    int numTokens = 0;\n",
    "    for (Integer value : counter.values()) {\n",
    "        numTokens += value;\n",
    "    }\n",
    "\n",
    "    // Now do the subsampling\n",
    "    String[][] output = new String[sentences.length][];\n",
    "    for (int i = 0; i < sentences.length; i++) {\n",
    "        ArrayList<String> tks = new ArrayList<>();\n",
    "        for (int j = 0; j < sentences[i].length; j++) {\n",
    "            String tk = sentences[i][j];\n",
    "            if (keep(sentences[i][j], counter, numTokens)) {\n",
    "                tks.add(tk);\n",
    "            }\n",
    "        }\n",
    "        output[i] = tks.toArray(new String[tks.size()]);\n",
    "    }\n",
    "\n",
    "    return output;\n",
    "}\n",
    "\n",
    "String[][] subsampled = subSampling(sentences, vocab);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的代码片段绘制了下采样前后每句话的词元数量的直方图。正如预期的那样，下采样通过删除高频词来显著缩短句子，这将使训练加速。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "double[] y1 = new double[sentences.length];\n",
    "for (int i = 0; i < sentences.length; i++) y1[i] = sentences[i].length;\n",
    "double[] y2 = new double[subsampled.length];\n",
    "for (int i = 0; i < subsampled.length; i++) y2[i] = subsampled[i].length;\n",
    "\n",
    "HistogramTrace trace1 =\n",
    "        HistogramTrace.builder(y1).opacity(.75).name(\"origin\").nBinsX(20).build();\n",
    "HistogramTrace trace2 =\n",
    "        HistogramTrace.builder(y2).opacity(.75).name(\"subsampled\").nBinsX(20).build();\n",
    "\n",
    "Layout layout =\n",
    "        Layout.builder()\n",
    "                .barMode(Layout.BarMode.GROUP)\n",
    "                .showLegend(true)\n",
    "                .xAxis(Axis.builder().title(\"# tokens per sentence\").build())\n",
    "                .yAxis(Axis.builder().title(\"count\").build())\n",
    "                .build();\n",
    "new Figure(layout, trace1, trace2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于单个词元，高频词“the”的采样率不到1/20。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String compareCounts(String token, String[][] sentences, String[][] subsampled) {\n",
    "    int beforeCount = 0;\n",
    "    for (int i = 0; i < sentences.length; i++) {\n",
    "        for (int j = 0; j < sentences[i].length; j++) {\n",
    "            if (sentences[i][j].equals(token)) beforeCount += 1;\n",
    "        }\n",
    "    }\n",
    "\n",
    "    int afterCount = 0;\n",
    "    for (int i = 0; i < subsampled.length; i++) {\n",
    "        for (int j = 0; j < subsampled[i].length; j++) {\n",
    "            if (subsampled[i][j].equals(token)) afterCount += 1;\n",
    "        }\n",
    "    }\n",
    "\n",
    "    return \"# of \\\"the\\\": before=\" + beforeCount + \", after=\" + afterCount;\n",
    "}\n",
    "\n",
    "System.out.println(compareCounts(\"the\", sentences, subsampled));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "相比之下，低频词“join”则被完全保留。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "System.out.println(compareCounts(\"join\", sentences, subsampled));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在下采样之后，我们将词元映射到它们在语料库中的索引。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Integer[][] corpus = new Integer[subsampled.length][];\n",
    "for (int i = 0; i < subsampled.length; i++) {\n",
    "    corpus[i] = vocab.getIdxs(subsampled[i]);\n",
    "}\n",
    "for (int i = 0; i < 3; i++) {\n",
    "    System.out.println(Arrays.toString(corpus[i]));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 中心词和上下文词的提取\n",
    "\n",
    "下面的`get_centers_and_contexts`函数从`corpus`中提取所有中心词及其上下文词。它随机采样1到`max_window_size`之间的整数作为上下文窗口。对于任一中心词，与其距离不超过采样上下文窗口大小的词为其上下文词。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> getCentersAndContext(\n",
    "        Integer[][] corpus, int maxWindowSize) {\n",
    "    ArrayList<Integer> centers = new ArrayList<>();\n",
    "    ArrayList<ArrayList<Integer>> contexts = new ArrayList<>();\n",
    "\n",
    "    for (Integer[] line : corpus) {\n",
    "        // Each sentence needs at least 2 words to form a \"central target word\n",
    "        // - context word\" pair\n",
    "        if (line.length < 2) {\n",
    "            continue;\n",
    "        }\n",
    "        centers.addAll(Arrays.asList(line));\n",
    "        for (int i = 0; i < line.length; i++) { // Context window centered at i\n",
    "            int windowSize = new Random().nextInt(maxWindowSize - 1) + 1;\n",
    "            List<Integer> indices =\n",
    "                    IntStream.range(\n",
    "                                    Math.max(0, i - windowSize),\n",
    "                                    Math.min(line.length, i + 1 + windowSize))\n",
    "                            .boxed()\n",
    "                            .collect(Collectors.toList());\n",
    "            // Exclude the central target word from the context words\n",
    "            indices.remove(indices.indexOf(i));\n",
    "            ArrayList<Integer> context = new ArrayList<>();\n",
    "            for (Integer idx : indices) {\n",
    "                context.add(line[idx]);\n",
    "            }\n",
    "            contexts.add(context);\n",
    "        }\n",
    "    }\n",
    "    return new Pair<>(centers, contexts);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们创建一个人工数据集，分别包含7个和3个单词的两个句子。设置最大上下文窗口大小为2，并打印所有中心词及其上下文词。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Integer[][] tinyDataset =\n",
    "        new Integer[][] {\n",
    "            IntStream.range(0, 7)\n",
    "                    .boxed()\n",
    "                    .collect(Collectors.toList())\n",
    "                    .toArray(new Integer[] {}),\n",
    "            IntStream.range(7, 10)\n",
    "                    .boxed()\n",
    "                    .collect(Collectors.toList())\n",
    "                    .toArray(new Integer[] {})\n",
    "        };\n",
    "\n",
    "System.out.println(\"dataset \" + Arrays.deepToString(tinyDataset));\n",
    "Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> centerContextPair =\n",
    "        getCentersAndContext(tinyDataset, 2);\n",
    "for (int i = 0; i < centerContextPair.getValue().size(); i++) {\n",
    "    System.out.println(\n",
    "            \"Center \"\n",
    "                    + centerContextPair.getKey().get(i)\n",
    "                    + \" has contexts\"\n",
    "                    + centerContextPair.getValue().get(i));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在PTB数据集上进行训练时，我们将最大上下文窗口大小设置为5。下面提取数据集中的所有中心词及其上下文词。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "centerContextPair = getCentersAndContext(corpus, 5);\n",
    "ArrayList<Integer> allCenters = centerContextPair.getKey();\n",
    "ArrayList<ArrayList<Integer>> allContexts = centerContextPair.getValue();\n",
    "System.out.println(\"中心词-上下文词对”的数量:\" + allCenters.size());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 负采样\n",
    "\n",
    "我们使用负采样进行近似训练。为了根据预定义的分布对噪声词进行采样，我们定义以下`RandomGenerator`类，其中（可能未规范化的）采样分布通过变量`samplingWeights`传递。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class RandomGenerator {\n",
    "    /* Draw a random int in [0, n] according to n sampling weights. */\n",
    "\n",
    "    private List<Integer> population;\n",
    "    private List<Double> samplingWeights;\n",
    "    private List<Integer> candidates;\n",
    "    private List<org.apache.commons.math3.util.Pair<Integer, Double>> pmf;\n",
    "    private int i;\n",
    "\n",
    "    public RandomGenerator(List<Double> samplingWeights) {\n",
    "        this.population =\n",
    "                IntStream.range(0, samplingWeights.size()).boxed().collect(Collectors.toList());\n",
    "        this.samplingWeights = samplingWeights;\n",
    "        this.candidates = new ArrayList<>();\n",
    "        this.i = 0;\n",
    "\n",
    "        this.pmf = new ArrayList<>();\n",
    "        for (int i = 0; i < samplingWeights.size(); i++) {\n",
    "            this.pmf.add(new org.apache.commons.math3.util.Pair(this.population.get(i), this.samplingWeights.get(i).doubleValue()));\n",
    "        }\n",
    "    }\n",
    "\n",
    "    public Integer draw() {\n",
    "        if (this.i == this.candidates.size()) {\n",
    "            this.candidates =\n",
    "                    Arrays.asList((Integer[]) new EnumeratedDistribution(this.pmf).sample(10000, new Integer[] {}));\n",
    "            this.i = 0;\n",
    "        }\n",
    "        this.i += 1;\n",
    "        return this.candidates.get(this.i - 1);\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "例如，我们可以在索引1、2和3中绘制10个随机变量$X$，采样概率为$P(X=1)=2/9, P(X=2)=3/9$和$P(X=3)=4/9$，如下所示。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "RandomGenerator generator =\n",
    "        new RandomGenerator(Arrays.asList(new Double[] {2.0, 3.0, 4.0}));\n",
    "Integer[] generatorOutput = new Integer[10];\n",
    "for (int i = 0; i < 10; i++) {\n",
    "    generatorOutput[i] = generator.draw();\n",
    "}\n",
    "System.out.println(Arrays.toString(generatorOutput));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于一对中心词和上下文词，我们随机抽取了`K`个（实验中为5个）噪声词。根据word2vec论文中的建议，将噪声词$w$的采样概率$P(w)$设置为其在字典中的相对频率，其幂为0.75 :cite:`Mikolov.Sutskever.Chen.ea.2013`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static ArrayList<ArrayList<Integer>> getNegatives(\n",
    "        ArrayList<ArrayList<Integer>> allContexts, Integer[][] corpus, int K) {\n",
    "    LinkedHashMap<?, Integer> counter = Vocab.countCorpus2D(corpus);\n",
    "    ArrayList<Double> samplingWeights = new ArrayList<>();\n",
    "    for (Map.Entry<?, Integer> entry : counter.entrySet()) {\n",
    "        samplingWeights.add(Math.pow(entry.getValue(), .75));\n",
    "    }\n",
    "    ArrayList<ArrayList<Integer>> allNegatives = new ArrayList<>();\n",
    "    RandomGenerator generator = new RandomGenerator(samplingWeights);\n",
    "    for (ArrayList<Integer> contexts : allContexts) {\n",
    "        ArrayList<Integer> negatives = new ArrayList<>();\n",
    "        while (negatives.size() < contexts.size() * K) {\n",
    "            Integer neg = generator.draw();\n",
    "            // Noise words cannot be context words\n",
    "            if (!contexts.contains(neg)) {\n",
    "                negatives.add(neg);\n",
    "            }\n",
    "        }\n",
    "        allNegatives.add(negatives);\n",
    "    }\n",
    "    return allNegatives;\n",
    "}\n",
    "\n",
    "ArrayList<ArrayList<Integer>> allNegatives = getNegatives(allContexts, corpus, 5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小批量加载训练实例\n",
    ":label:`subsec_word2vec-minibatch-loading`\n",
    "\n",
    "在提取所有中心词及其上下文词和采样噪声词后，将它们转换成小批量的样本，在训练过程中可以迭代加载。\n",
    "\n",
    "在小批量中，$i^\\mathrm{th}$个样本包括中心词及其$n_i$个上下文词和$m_i$个噪声词。由于上下文窗口大小不同，$n_i+m_i$对于不同的$i$是不同的。因此，对于每个样本，我们在`contexts_negatives`个变量中将其上下文词和噪声词连结起来，并填充零，直到连结长度达到$\\max_i n_i+m_i$(`max_len`)。为了在计算损失时排除填充，我们定义了掩码变量`masks`。在`masks`中的元素和`contexts_negatives`中的元素之间存在一一对应关系，其中`masks`中的0（否则为1）对应于`contexts_negatives`中的填充。\n",
    "\n",
    "为了区分正反例，我们在`contexts_negatives`中通过一个`labels`变量将上下文词与噪声词分开。类似于`masks`，在`labels`中的元素和`contexts_negatives`中的元素之间也存在一一对应关系，其中`labels`中的1（否则为0）对应于`contexts_negatives`中的上下文词的正例。\n",
    "\n",
    "上述思想在下面的`batchify`函数中实现。其输入`data`是长度等于批量大小的列表，其中每个元素是由中心词`center`、其上下文词`context`和其噪声词`negative`组成的样本。此函数返回一个可以在训练期间加载用于计算的小批量，例如包括掩码变量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList batchifyData(NDList[] data) {\n",
    "    NDList centers = new NDList();\n",
    "    NDList contextsNegatives = new NDList();\n",
    "    NDList masks = new NDList();\n",
    "    NDList labels = new NDList();\n",
    "\n",
    "    long maxLen = 0;\n",
    "    for (NDList ndList : data) { // center, context, negative = ndList\n",
    "        maxLen =\n",
    "                Math.max(\n",
    "                        maxLen,\n",
    "                        ndList.get(1).countNonzero().getLong()\n",
    "                                + ndList.get(2).countNonzero().getLong());\n",
    "    }\n",
    "    for (NDList ndList : data) { // center, context, negative = ndList\n",
    "        NDArray center = ndList.get(0);\n",
    "        NDArray context = ndList.get(1);\n",
    "        NDArray negative = ndList.get(2);\n",
    "\n",
    "        int count = 0;\n",
    "        for (int i = 0; i < context.size(); i++) {\n",
    "            // If a 0 is found, we want to stop adding these\n",
    "            // values to NDArray\n",
    "            if (context.get(i).getInt() == 0) {\n",
    "                break;\n",
    "            }\n",
    "            contextsNegatives.add(context.get(i).reshape(1));\n",
    "            masks.add(manager.create(1).reshape(1));\n",
    "            labels.add(manager.create(1).reshape(1));\n",
    "            count += 1;\n",
    "        }\n",
    "        for (int i = 0; i < negative.size(); i++) {\n",
    "            // If a 0 is found, we want to stop adding these\n",
    "            // values to NDArray\n",
    "            if (negative.get(i).getInt() == 0) {\n",
    "                break;\n",
    "            }\n",
    "            contextsNegatives.add(negative.get(i).reshape(1));\n",
    "            masks.add(manager.create(1).reshape(1));\n",
    "            labels.add(manager.create(0).reshape(1));\n",
    "            count += 1;\n",
    "        }\n",
    "        // Fill with zeroes remaining array\n",
    "        while (count != maxLen) {\n",
    "            contextsNegatives.add(manager.create(0).reshape(1));\n",
    "            masks.add(manager.create(0).reshape(1));\n",
    "            labels.add(manager.create(0).reshape(1));\n",
    "            count += 1;\n",
    "        }\n",
    "\n",
    "        // Add this NDArrays to output NDArrays\n",
    "        centers.add(center.reshape(1));\n",
    "    }\n",
    "    return new NDList(\n",
    "            NDArrays.concat(centers).reshape(data.length, -1),\n",
    "            NDArrays.concat(contextsNegatives).reshape(data.length, -1),\n",
    "            NDArrays.concat(masks).reshape(data.length, -1),\n",
    "            NDArrays.concat(labels).reshape(data.length, -1));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们使用一个小批量的两个样本来测试此函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList x1 =\n",
    "        new NDList(\n",
    "                manager.create(new int[] {1}),\n",
    "                manager.create(new int[] {2, 2}),\n",
    "                manager.create(new int[] {3, 3, 3, 3}));\n",
    "NDList x2 =\n",
    "        new NDList(\n",
    "                manager.create(new int[] {1}),\n",
    "                manager.create(new int[] {2, 2, 2}),\n",
    "                manager.create(new int[] {3, 3}));\n",
    "\n",
    "NDList batchedData = batchifyData(new NDList[] {x1, x2});\n",
    "String[] names = new String[] {\"centers\", \"contexts_negatives\", \"masks\", \"labels\"};\n",
    "for (int i = 0; i < batchedData.size(); i++) {\n",
    "    System.out.println(names[i] + \" shape: \" + batchedData.get(i));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 整合代码\n",
    "\n",
    "最后，我们定义了读取PTB数据集并返回数据迭代器和词表的`load_data_ptb`函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList convertNDArray(Object[] data, NDManager manager) {\n",
    "    ArrayList<Integer> centers = (ArrayList<Integer>) data[0];\n",
    "    ArrayList<ArrayList<Integer>> contexts = (ArrayList<ArrayList<Integer>>) data[1];\n",
    "    ArrayList<ArrayList<Integer>> negatives = (ArrayList<ArrayList<Integer>>) data[2];\n",
    "\n",
    "    // Create centers NDArray\n",
    "    NDArray centersNDArray = manager.create(centers.stream().mapToInt(i -> i).toArray());\n",
    "\n",
    "    // Create contexts NDArray\n",
    "    int maxLen = 0;\n",
    "    for (ArrayList<Integer> context : contexts) {\n",
    "        maxLen = Math.max(maxLen, context.size());\n",
    "    }\n",
    "    // Fill arrays with 0s to all have same lengths and be able to create NDArray\n",
    "    for (ArrayList<Integer> context : contexts) {\n",
    "        while (context.size() != maxLen) {\n",
    "            context.add(0);\n",
    "        }\n",
    "    }\n",
    "    NDArray contextsNDArray =\n",
    "            manager.create(\n",
    "                    contexts.stream()\n",
    "                            .map(u -> u.stream().mapToInt(i -> i).toArray())\n",
    "                            .toArray(int[][]::new));\n",
    "\n",
    "    // Create negatives NDArray\n",
    "    maxLen = 0;\n",
    "    for (ArrayList<Integer> negative : negatives) {\n",
    "        maxLen = Math.max(maxLen, negative.size());\n",
    "    }\n",
    "    // Fill arrays with 0s to all have same lengths and be able to create NDArray\n",
    "    for (ArrayList<Integer> negative : negatives) {\n",
    "        while (negative.size() != maxLen) {\n",
    "            negative.add(0);\n",
    "        }\n",
    "    }\n",
    "    NDArray negativesNDArray =\n",
    "            manager.create(\n",
    "                    negatives.stream()\n",
    "                            .map(u -> u.stream().mapToInt(i -> i).toArray())\n",
    "                            .toArray(int[][]::new));\n",
    "\n",
    "    return new NDList(centersNDArray, contextsNDArray, negativesNDArray);\n",
    "}\n",
    "\n",
    "public static Pair<ArrayDataset, Vocab> loadDataPTB(\n",
    "        int batchSize, int maxWindowSize, int numNoiseWords, NDManager manager)\n",
    "        throws IOException, TranslateException {\n",
    "    String[][] sentences = readPTB();\n",
    "    Vocab vocab = new Vocab(sentences, 10, new String[] {});\n",
    "    String[][] subSampled = subSampling(sentences, vocab);\n",
    "    Integer[][] corpus = new Integer[subSampled.length][];\n",
    "    for (int i = 0; i < subSampled.length; i++) {\n",
    "        corpus[i] = vocab.getIdxs(subSampled[i]);\n",
    "    }\n",
    "    Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> pair =\n",
    "            getCentersAndContext(corpus, maxWindowSize);\n",
    "    ArrayList<ArrayList<Integer>> negatives =\n",
    "            getNegatives(pair.getValue(), corpus, numNoiseWords);\n",
    "\n",
    "    NDList ndArrays =\n",
    "            convertNDArray(new Object[] {pair.getKey(), pair.getValue(), negatives}, manager);\n",
    "    ArrayDataset dataset =\n",
    "            new ArrayDataset.Builder()\n",
    "                    .setData(ndArrays.get(0), ndArrays.get(1), ndArrays.get(2))\n",
    "                    .optDataBatchifier(\n",
    "                            new Batchifier() {\n",
    "                                @Override\n",
    "                                public NDList batchify(NDList[] ndLists) {\n",
    "                                    return batchifyData(ndLists);\n",
    "                                }\n",
    "\n",
    "                                @Override\n",
    "                                public NDList[] unbatchify(NDList ndList) {\n",
    "                                    return new NDList[0];\n",
    "                                }\n",
    "                            })\n",
    "                    .setSampling(batchSize, true)\n",
    "                    .build();\n",
    "\n",
    "    return new Pair<>(dataset, vocab);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们打印数据迭代器的第一个小批量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pair<ArrayDataset, Vocab> datasetVocab = loadDataPTB(512, 5, 5, manager);\n",
    "ArrayDataset dataset = datasetVocab.getKey();\n",
    "vocab = datasetVocab.getValue();\n",
    "\n",
    "Batch batch = dataset.getData(manager).iterator().next();\n",
    "for (int i = 0; i < batch.getData().size(); i++) {\n",
    "    System.out.println(names[i] + \" shape: \" + batch.getData().get(i).getShape());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 高频词在训练中可能不是那么有用。我们可以对他们进行下采样，以便在训练中加快速度。\n",
    "* 为了提高计算效率，我们以小批量方式加载样本。我们可以定义其他变量来区分填充标记和非填充标记，以及正例和负例。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 如果不使用下采样，本节中代码的运行时间会发生什么变化？\n",
    "1. `RandomGenerator`类缓存`k`个随机采样结果。将`k`设置为其他值，看看它如何影响数据加载速度。\n",
    "1. 本节代码中的哪些其他超参数可能会影响数据加载速度？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
